Add new Relax function to the batched model for evaluating query tokens over multiple time steps in parallel #156
Conversation
Thank you for the PR, @masahi! Which TVM should I use to run this?
For now we need TVM from https://github.com/masahi/tvm/tree/vllm-cache-reconstruct. After apache/tvm#16376 is merged, I'll do a rebase.
Force-pushed from ccbfb6e to 09ef5b3.
Force-pushed from 4001d61 to 7b67ba4.
Opened #157 which uses the new Relax function from this PR to enable parallel-sampling eviction.
This looks great and should be sufficient for speculative decoding with a draft model. By the way, is it still necessary to keep `decode` after we have a good kernel for `evaluate_multi_query`? Will there be a performance loss if we run `evaluate_multi_query` with one token from each sequence? If not, maybe we can just name this `decode`. Maybe we can even retire `prefill` if the kernel can specialize, without degrading performance, for the case where it doesn't need to read past KV from the cache.
This is an interesting idea. I'd like to think that specialization allows perf advantages. The comparison is a bit subtle, since moving from a single query to multiple ones involves switching to entirely different kernel implementations (vLLM to FlashAttention / FlashInfer), so perf can be affected by any number of factors besides the increase in the number of query tokens.
This PR reorganizes the artifact structure. We now have two separate types of directories to store the libs/weights/...: a "prebuilt" directory which holds all the prebuilt libs and weights downloaded from the internet, and the other model directories which are generated by local builds. The CLI and test scripts are updated accordingly.
In speculative decoding, and in restoring KV cache entries for evicted parallel-sampling requests, we need to be able to compute logits over multiple tokens (time steps) while utilizing the KV cache for the past tensors. This is a hybrid of the `prefill` and `decode` functions, in that `prefill` can compute logits over multiple tokens but doesn't read from the KV cache, while `decode` reads from the cache but works on only one token at a time. I'm introducing a new function, tentatively called `evaluate_multi_query`, for this purpose. `multi_query_decode` is also a good name.
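To make the distinction concrete, here is a minimal sketch of how the new entry point relates to the existing two. The dataclass and field names below are illustrative assumptions, not the actual types or interfaces in this PR.

```python
# Hypothetical request shape for the new function: several query tokens per
# sequence, plus the number of tokens whose K/V already sit in the cache.
from dataclasses import dataclass
from typing import List

@dataclass
class MultiQueryRequest:
    sequence_id: int
    query_token_ids: List[int]  # multiple new tokens to evaluate in one step
    num_past_tokens: int        # tokens whose K/V already live in the KV cache

# Conceptually:
#   prefill:              many query tokens, no past KV read from the cache
#   decode:               one query token per sequence, past KV read from cache
#   evaluate_multi_query: many query tokens per sequence AND past KV read
#                         from the cache (the hybrid described above)
```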
The changes in `run_llama_batched_vllm.py` show a new request type and how the new function is meant to be used. There is no change under `serve` yet, since this is purely a model change. After we agree on the approach, I'll integrate this new function into the engine to complete my parallel-sampling work. @yelite needs this for speculative decoding.

There is no attention kernel that reads from the KV cache and operates on multiple queries, except FlashInfer, which has `BatchedPrefillWithKVCache`. But we can emulate the behavior of such a kernel by materializing the past KV tensors from the cache, concatenating them with the present tensors, and running the standard prefill attention. This is not efficient, but its correctness is much easier to verify. Until we integrate FlashInfer, or FlashAttention adds paged KV cache support, we can use this emulation (see the sketch below).

@sunggg @yelite @elvin-n
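For reference, here is a minimal NumPy sketch of that emulation for a single sequence, assuming the past K/V have already been gathered out of the cache. The function name, shapes, and mask construction are illustrative only, not the actual kernels or layouts used in the model.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def emulated_multi_query_attention(q, k_new, v_new, k_past, v_past):
    """Attention for several new query tokens of one sequence.

    q, k_new, v_new: [num_query_tokens, num_heads, head_dim] (present step)
    k_past, v_past:  [num_past_tokens,  num_heads, head_dim] (read from cache)
    Returns:         [num_query_tokens, num_heads, head_dim]
    """
    num_past, num_q, head_dim = k_past.shape[0], q.shape[0], q.shape[-1]
    # Concatenate past and present K/V so the standard prefill-style
    # attention can be reused unchanged.
    k = np.concatenate([k_past, k_new], axis=0)   # [num_past + num_q, H, D]
    v = np.concatenate([v_past, v_new], axis=0)
    scores = np.einsum("qhd,khd->hqk", q, k) / np.sqrt(head_dim)
    # Causal mask offset by the number of past tokens: query i may attend to
    # every past token plus present tokens 0..i.
    kv_pos = np.arange(num_past + num_q)[None, :]    # [1, num_kv]
    q_pos = (num_past + np.arange(num_q))[:, None]   # [num_q, 1]
    scores = np.where(kv_pos <= q_pos, scores, -np.inf)
    attn = softmax(scores, axis=-1)
    return np.einsum("hqk,khd->qhd", attn, v)

# Example: 2 new query tokens, 3 cached tokens, 4 heads, head_dim 8.
rng = np.random.default_rng(0)
q, kn, vn = (rng.standard_normal((2, 4, 8)) for _ in range(3))
kp, vp = (rng.standard_normal((3, 4, 8)) for _ in range(2))
out = emulated_multi_query_attention(q, kn, vn, kp, vp)  # shape (2, 4, 8)
```

With `num_query_tokens == 1` this reduces to the usual decode step, which is the observation behind the naming question in the review comment above.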